
Add an option not to abort on cuda OOM #1110

Open
wants to merge 1 commit into master

Conversation

WilliamTambellini (Contributor)

Warning: not ready for merge.
Add an option not to abort on CUDA OOM, but to return a ggml_status instead.
The goal is NOT to be able to continue decoding after an OOM, but just to do a clean, controlled exit at a higher level.
Needs cmake GGML_NO_ABORT_ON_OOM=ON (default OFF).
Retouch ggml_tallocr_alloc to return a ggml_status. Retouch init_tensor to return a ggml_status.
Add a bool option for ggml_cuda_error() to abort or not, default true. Add a new macro CUDA_CHECK_NO_ABORT().
Add a new unit test to check the GGML_NO_ABORT_ON_OOM flow.

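For illustration, a minimal sketch of the clean, controlled exit the description aims for, from the application side (hedged: the checked call and the cleanup steps are illustrative, not code from this PR; ggml_status_to_string is part of ggml's public API, and the status-returning init_tensor path is what this PR proposes):

// Hypothetical caller-side flow once allocation failures surface as a
// ggml_status instead of an abort (built with GGML_NO_ABORT_ON_OOM=ON).
enum ggml_status st = ggml_backend_buffer_init_tensor(buffer, tensor);
if (st != GGML_STATUS_SUCCESS) {
    fprintf(stderr, "ggml: tensor init failed: %s\n", ggml_status_to_string(st));
    ggml_backend_buffer_free(buffer); // release resources, then exit cleanly
    exit(1);
}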
@slaren (Member) left a comment:

I am not convinced about the approach in the CUDA backend. It will require a lot of changes to change every single CUDA_CHECK. I would consider changing CUDA_CHECK to throw an exception instead, and catching them in the ggml-backend functions. The ggml-backend functions must never leak exceptions, so consider adding a noexcept to all the ggml-backend interface functions when building from C++. This will also require ensuring that every resource is allocated via RAII in an exception-safe manner.
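A minimal sketch of that boundary pattern, assuming CUDA_CHECK is changed to throw (the function shape is illustrative, not code from this PR; the GGML_STATUS_* values are from ggml's public enum ggml_status):

// Sketch: a ggml-backend interface function, compiled as C++, that must not
// leak exceptions. noexcept documents the no-leak guarantee at the C boundary.
static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) noexcept {
    try {
        // ... CUDA_CHECK(...) calls that may throw on failure ...
        return GGML_STATUS_SUCCESS;
    } catch (const std::bad_alloc &) {
        return GGML_STATUS_ALLOC_FAILED; // OOM: report instead of aborting
    } catch (const std::exception &) {
        return GGML_STATUS_FAILED;       // any other CUDA failure
    }
}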

@@ -19,7 +19,7 @@ struct ggml_tallocr {
};

GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
-GGML_API void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
+GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
Member:

I don't think it is necessary to change this function, since it does not allocate any memory itself. All errors from this function can be prevented by ensuring that the buffer has enough space.

Contributor Author:

ok, reverted

@@ -150,6 +155,7 @@ static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offs
}
#endif

+// Check with reviewer: could this function return a ggml_status (offset being an arg)?
Member:

This function also does not allocate any (physical) memory, it is just calculating offsets within a buffer. If it fails, it means there is a bug somewhere else.

Contributor Author:

ok, but note that it can still abort.

Member:

The abort is mostly a sanity check; it cannot happen if everything is working as expected. If it fails, it means there is a serious bug in ggml.

Comment on lines +681 to +682
+// Returns true on success, false otherwise
+// Check with reviewers: any cons to returning a ggml_status?
Member:

It would be ok to change the gallocr functions to return a ggml_status.

Contributor Author:

ok, retouching that PR.

@@ -44,7 +44,7 @@ extern "C" {
// base address of the buffer
void * (*get_base) (ggml_backend_buffer_t buffer);
// (optional) initialize a tensor in the buffer (eg. add tensor extras)
-void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+enum ggml_status (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
Member:

All backends that use this function will need to be updated. It would be preferable to open the PR in llama.cpp since it has much better CI.

Contributor Author:

ok. To move forward step by step, would you accept a PR in llama.cpp with just that init_tensor change?

Member:

Yes.
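For context, a hedged sketch of how the C-side dispatcher in ggml-backend could forward the new return value (the exact body is an assumption, not code from this PR; init_tensor is an optional interface member, so a missing implementation counts as success):

// Sketch: wrapper forwarding the status from the backend's init_tensor hook.
enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    if (buffer->iface.init_tensor == NULL) {
        return GGML_STATUS_SUCCESS; // optional hook: nothing to initialize
    }
    return buffer->iface.init_tensor(buffer, tensor);
}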

@@ -79,18 +79,19 @@

#define GGML_CUDA_MAX_STREAMS 8

-[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
+// Print the error. Will also abort if abort is true.
Member:

I am not sure that the abort parameter is necessary. The cuBLAS functions may also allocate memory and fail (CUBLAS_CHECK).

Contributor Author:

I would propose to keep the abort bool option since it's up to the developer to decide whether to allow aborting or not.
For cuBLAS, I could add a CUBLAS_CHECK_NO_ABORT() if you'd like me to.

Member:

As I mentioned below, I do not think there is any case where aborting on a CUDA call failure is acceptable. We must allow applications to deal with these errors, we can't just make their applications disappear without explanation when something unexpected happens.

Contributor Author:

Hum, at least we agree that the abort is/was too brutal.
I introduced the abort bool in order to distinguish between CUDA failures, which today abort, and OOM failures, which today also abort but for which we don't want that.
At the moment our goal is just to catch OOMs, not to handle and forward upward all CUDA failures (OOM or not).
So you propose to extend the scope of this PR to all CUDA failures, right?

@slaren (Member) commented Feb 12, 2025:

It's not necessary to extend the scope of the PR; you can leave the aborts on functions that don't have a way to return an error, like the buffer functions. However, you will still need to catch the exceptions and turn them into a GGML_ABORT. In the future we can extend the ggml API to return errors in more conditions. Adding an abort parameter is just going to add a lot of changes that will need to be reverted in the future anyway.

Contributor Author:

hum, so something like

try {
   CUDA_CHECK(dosomething());
} catch (const std::exception &) { GGML_ABORT(); }

would be a nightmare, as there are hundreds of CUDA_CHECK calls in ggml-cuda.cu.

Wouldn't it be simpler to add the throw in CUDA_CHECK_GEN:

#define CUDA_CHECK_GEN(err, success, error_fn)                                       \
    do {                                                                             \
        auto err_ = (err);                                                           \
        if (err_ != (success)) {                                                     \
            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_));     \
            throw (err_ == oom ? std::bad_alloc(...) : std::runtime_error(...));     \
        }                                                                            \
    } while (0)

?

Member:

You don't need a try..catch block for every CUDA_CHECK, only one for each ggml-backend interface function. For example:

static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) try {
    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

    ggml_cuda_set_device(ctx->device);
    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
catch (const std::exception & e) {
    GGML_ABORT("%s", e.what());
}

Member:

If it was an easy refactor, we would have already done it. If you add an abort parameter to every CUDA_CHECK, you will be adding to the work that will need to be done in the future.

@@ -1681,6 +1681,7 @@ void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes) {
}

struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
+    GGML_ASSERT(src);
Member:

Do not check for NULL pointers with GGML_ASSERT; there are way too many cases, and it would massively bloat the code to add so many checks. If necessary, use assert instead so that it is only enabled in debug builds.

Contributor Author:

the simple C assert() is not that useful, but better than nothing. Done.

Member:

I would also expect some consistency, though: what is the reason for checking for NULL pointers only here, seemingly at random? Nearly every ggml function takes a pointer.

Contributor Author:

Sure, I was just playing with these ops. I'll remove the asserts; I don't want to extend the scope of this PR further.

@leok7v commented Feb 12, 2025:

Same goes for Metal on macOS. At the moment Metal is usable on some pre-Apple Silicon Macs (amazingly, actually), and on some it just crashes or hangs while allocating kernels. Maybe not worth the effort, though.

@WilliamTambellini (Contributor Author):

Tks @slaren
Please reply to these questions before I retouch this PR and/or prepare the PR in llama.cpp.

> I am not convinced about the approach in the CUDA backend. It will require a lot of changes to change every single CUDA_CHECK. I would consider changing CUDA_CHECK to throw an exception instead, and catching them in the ggml-backend functions. The ggml-backend functions must never leak exceptions, so consider adding a noexcept to all the ggml-backend interface functions when building from C++. This will also require ensuring that every resource is allocated via RAII in an exception-safe manner.

I'm OK with retouching it the way you prefer, but please be precise in order to save time for everybody.
Do you propose something like:

void ggml_cuda_error(...) {
  ...
  if (THROW_ON_CUDA_ERROR)
    throw std::runtime_error(...);
  else
    GGML_ABORT(...);
}

?
Tks

@graehl commented Feb 12, 2025:

Since a backend cannot allow exceptions to escape (to allow dynamic linking at the backend API boundary), it seems we only have the question of which backend functions need modification to allow non-abort handling of allocation failures. Is it really only ggml_tallocr_alloc and init_tensor, @WilliamTambellini?
I suggest documenting in the new API that the backend is still usable after such a return (i.e. immediately retrying with smaller input is permissible).
Clearly there are many aborts in ggml generally (and the cuda backend arguably) that should stay as aborts in this PR.
@slaren is correct that provisions to clean up on OOM return (formerly abort) would need to be verified and/or created before the cuda backend can correctly return an OOM error instead of aborting.
slaren's idea that exceptions could be used internally in the cuda backend makes sense (and implies that the cuda backend's internal cleanup provisions must be RAII + exception safe).
I'm not sure about slaren's suggestion of making every CUDA_CHECK throw an exception, as this increases the number of code paths that need to be made exception safe, considering the overall purpose of this PR is just to support orderly error return for OOM conditions. But if ggml veterans think that there will be additional recoverable-error conditions then this more general approach could be a decent investment.

@slaren (Member) commented Feb 12, 2025:

Yes, that's what I am proposing. Throw an exception in the CHECK macros in case of failure, and catch them in the ggml-backend functions that can fail to return an error to the caller.

@slaren (Member) commented Feb 12, 2025:

> I'm not sure about slaren's suggestion of making every CUDA_CHECK throw an exception, as this increases the number of code paths that need to be made exception safe, considering the overall purpose of this PR is just to support orderly error return for OOM conditions. But if ggml veterans think that there will be additional recoverable-error conditions then this more general approach could be a decent investment.

My opinion is that we should only abort when some pre-condition that is expected to be met by the caller is not. These are programming errors that indicate that ggml is not being used correctly, and usually can be fixed easily. However, we should never crash the application just because a CUDA function returns an error - we must always provide applications some way to recover from this, or at least give them a chance to shut down cleanly.

@WilliamTambellini (Contributor Author):

> Yes, that's what I am proposing. Throw an exception in the CHECK macros in case of failure, and catch them in the ggml-backend functions that can fail to return an error to the caller.

Hum, your reply is confusing: my question was about retouching the ggml_cuda_error() fn, but you are speaking about the CHECK macros.
Again, please let's be precise: are you speaking about
CUDA_CHECK_GEN(...)?
CUDA_CHECK(...)?
Tks

@slaren (Member) commented Feb 12, 2025:

I mentioned the CHECK macros because that's what the code uses to check CUDA calls; ggml_cuda_error is just an implementation detail of the CHECK macros. All the CHECK macros call ggml_cuda_error, so it's ok to throw the exception from it.

@graehl commented Feb 12, 2025:

Ok, the scope of "all CUDA errors" makes sense, and @slaren is of course correct that inserting a throw in the CUDA_CHECK macros to be caught at or before the backend API level (along with making all users of the macro and their callers exception-safe) would achieve this.

@WilliamTambellini (Contributor Author):

@slaren we are perhaps moving forward, although I would prefer not to extend the scope of this PR to all cuda errors. Now, if you tell me precisely what you prefer, perhaps I could do it.
Do you propose something like:

void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
    int id = -1; // in case cudaGetDevice fails
    (void)cudaGetDevice(&id);

    GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
    GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
    GGML_LOG_ERROR("  %s\n", stmt);
#ifndef __CUDA_ARCH__
    throw std::runtime_error(msg);
#endif
}

?

@slaren (Member) commented Feb 12, 2025:

Yes. To summarize:

  • Throw exceptions on CUDA error
  • Catch them in all the ggml-backend interface functions that call CUDA functions that may fail
  • In cases where returning an error from the function is currently not possible without significant changes to the ggml-backend interface, make it abort anyway after catching the exception
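To close the loop on the RAII requirement raised at the top of the review, a minimal sketch of exception-safe ownership of a device allocation (the guard type and its name are hypothetical; cudaMalloc/cudaFree are standard CUDA runtime calls; assumes <cuda_runtime.h> and <new> are included):

// Sketch: if a later CUDA_CHECK throws, the destructor still runs and the
// device memory is freed, so nothing leaks on the way to the catch site.
struct cuda_buffer_guard {
    void * ptr = nullptr;
    explicit cuda_buffer_guard(size_t size) {
        if (cudaMalloc(&ptr, size) != cudaSuccess) {
            throw std::bad_alloc(); // OOM becomes an exception, not an abort
        }
    }
    ~cuda_buffer_guard() {
        if (ptr != nullptr) {
            (void)cudaFree(ptr); // never throw from a destructor
        }
    }
    cuda_buffer_guard(const cuda_buffer_guard &) = delete;
    cuda_buffer_guard & operator=(const cuda_buffer_guard &) = delete;
};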
